import os

import cdtools
import numpy as np
import torch as t
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
from scipy.io import loadmat, savemat

from helper_functions import center_probe, center_probe_fourier


def clean_and_crop_dataset(dataset, dead_idx=None, intensity_threshold=0,
                           *, background_rows=100, crop=100):
    """Drop bad shots, subtract a background estimate and crop the edges.

    The dataset's attributes are replaced in place and the same object is
    returned for convenience.

    Parameters
    ----------
    dataset : cdtools.datasets.Ptycho2DDataset
        Dataset whose ``patterns``, ``translations``, ``intensities`` and
        (optionally) ``mask`` attributes are processed. ``patterns`` is
        assumed to be (N, H, W) and ``mask`` (H, W) -- TODO confirm.
    dead_idx : iterable of int, optional
        Extra pattern indices to discard regardless of their intensity.
        Negative indices count from the end, as in ordinary indexing.
    intensity_threshold : float
        Patterns whose recorded intensity is ``<=`` this value are discarded.
    background_rows : int
        Number of leading rows (along dim 1 of ``patterns``) whose
        per-column median is subtracted from each pattern as a background.
    crop : int
        Number of pixels removed from every edge of the last two dimensions
        of ``patterns`` and ``mask``.

    Returns
    -------
    dataset
        The same object that was passed in, after processing.

    Raises
    ------
    ValueError
        If ``dead_idx`` contains an out-of-range index, or ``crop`` is
        negative or would remove the whole detector.
    """
    n_patterns = dataset.patterns.shape[0]

    # Sometimes the translations have more data than the patterns, so we need
    # to cut them down to size.
    dataset.translations = dataset.translations[:n_patterns]
    dataset.intensities = dataset.intensities[:n_patterns]

    # Remove the weak images (low intensity from the FEL). Written as
    # ~(x <= threshold) rather than x > threshold so that NaN intensities are
    # kept, exactly as before.
    keep = ~(dataset.intensities <= intensity_threshold)

    # Remove explicitly requested shots. A boolean mask (rather than slicing
    # between sorted dead indices) copes with duplicates and negative indices.
    dead = t.as_tensor(list(dead_idx) if dead_idx is not None else [],
                       dtype=t.long)
    out_of_range = (dead >= n_patterns) | (dead < -n_patterns)
    if out_of_range.any():
        raise ValueError(
            f'dead_idx contains indices out of range for {n_patterns} '
            f'patterns: {dead[out_of_range].tolist()}')
    keep[dead.to(keep.device)] = False

    dataset.patterns = dataset.patterns[keep.to(dataset.patterns.device)]
    dataset.translations = \
        dataset.translations[keep.to(dataset.translations.device)]
    dataset.intensities = dataset.intensities[keep]

    # Per-pattern, per-column median of the first `background_rows` rows.
    background_estimate = t.median(
        dataset.patterns[:, :background_rows, :], dim=1, keepdim=True).values
    dataset.patterns -= background_estimate

    dataset.patterns = _crop_edges(dataset.patterns, crop)
    # Datasets loaded without a mask have mask=None; leave that untouched.
    if getattr(dataset, 'mask', None) is not None:
        dataset.mask = _crop_edges(dataset.mask, crop)

    return dataset


def _crop_edges(tensor, crop):
    """Return a contiguous copy of ``tensor`` minus ``crop`` pixels per edge of its last two dims."""
    height, width = tensor.shape[-2:]
    if crop < 0 or 2 * crop >= min(height, width):
        raise ValueError(
            f'crop={crop} is invalid for a {height}x{width} detector')
    # .contiguous() copies the view so the uncropped tensor can be freed.
    return tensor[..., crop:height - crop, crop:width - crop].contiguous()


# Root of the project; the raw CXI files are read from its 'reprocessed_cxis'
# sub-folder. Cleaned datasets are written to output_folder, which is relative
# to the current working directory.
base_folder = '/das/work/p21/p21561/projects/20250822_Redoing_Magnetic_RPI_DiProI/'
input_folder = os.path.join(base_folder, 'reprocessed_cxis')
output_folder = 'data'

# Shots whose recorded intensity is <= this value are discarded (weak FEL
# pulses). The same cut is applied to every dataset below.
intensity_threshold = 36

# (progress message, input CXI file, output CXI file) for each dataset.
datasets_to_process = [
    ('Processing the single-shot left-hand circularly polarized ptychography dataset',
     'ptycho_lcp_035.cxi', 'magnetic_ptycho_lcp.cxi'),
    ('Processing the single-shot left-hand circularly polarized RPI dataset',
     'rpi_lcp_037.cxi', 'magnetic_rpi_lcp.cxi'),
    ('Processing the single-shot right-hand circularly polarized ptychography dataset',
     'ptycho_rcp_042.cxi', 'magnetic_ptycho_rcp.cxi'),
]

# Make sure the output folder exists before any dataset is written into it.
os.makedirs(output_folder, exist_ok=True)

for message, input_name, output_name in datasets_to_process:
    print(message)

    dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(
        os.path.join(input_folder, input_name), cut_zeros=False)

    dataset = clean_and_crop_dataset(
        dataset, intensity_threshold=intensity_threshold)
    dataset.to_cxi(os.path.join(output_folder, output_name))

    # NOTE(review): assumes inspect() only creates matplotlib figures, which
    # are then all displayed together by plt.show() below.
    dataset.inspect()

plt.show()
